-
Notifications
You must be signed in to change notification settings - Fork 137
Refactor and update QR Op #1518
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor and update QR Op #1518
Conversation
5bc044c
to
be949cd
Compare
71639ef
to
a6d6c11
Compare
a6d6c11
to
112f6fd
Compare
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (49.75%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1518 +/- ##
==========================================
- Coverage 81.85% 81.45% -0.41%
==========================================
Files 230 232 +2
Lines 52522 53027 +505
Branches 9345 9422 +77
==========================================
+ Hits 42992 43192 +200
- Misses 7095 7390 +295
- Partials 2435 2445 +10
🚀 New features to boost your workflow:
|
in_dtype = config.floatX if integer_input else dtype | ||
|
||
@numba_njit(cache=False) | ||
def qr(a): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to worry about a
not being F-contiguous
like other lapack/blas stuff?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, it's done in the lower level functions
Description
This PR updates the QRFull Op, adding static shape checking, infer_shape, and destroy_map. It also optimizes the perform method for the C backend, and tries to improve the gradient graph by checking static shapes (to avoid an ifelse).
I renamed it to QR, because I don't know what was Full about the old one. I also moved it from the numpy implementation to scipy, which gives us all the usual benefits (inplace, etc). I also went ahead and unpacked the scipy wrapper and used the LAPACK functions directly. This will give us better error handling (that is to say, none -- it should eventually return a matrix of NaN on failure) and some performance boost by caching workspace requirements.
Still a WIP, because it breaks everything by moving QR from nlinalg to slinalg. I thought about using this as an opportunity to finally eliminate this distinction and go to a more logical organization (linalg/decomposition/qr.py), but then decided against it for now. Needs discussion.
Related Issue
infer_shape
method toQRFull
#1511Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1518.org.readthedocs.build/en/1518/